{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matrix multiplication from foundations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The *foundations* we'll assume throughout this course are:\n",
    "\n",
    "- Python\n",
    "- Python modules (non-DL)\n",
    "- pytorch indexable tensor, and tensor creation (including RNGs - random number generators)\n",
    "- fastai.datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=1850)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_00 import *\n",
    "import operator\n",
    "\n",
    "def test(a,b,cmp,cname=None):\n",
    "    if cname is None: cname=cmp.__name__\n",
    "    assert cmp(a,b),f\"{cname}:\\n{a}\\n{b}\"\n",
    "\n",
    "def test_eq(a,b): test(a,b,operator.eq,'==')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(TEST,'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To run tests in console:\n",
    "# ! python run_notebook.py 01_matmul.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=2159)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from pathlib import Path\n",
    "from IPython.core.debugger import set_trace\n",
    "from fastai import datasets\n",
    "import pickle, gzip, math, torch, matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from torch import tensor\n",
    "\n",
    "MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PosixPath('/home/ubuntu/.fastai/data/mnist.pkl.gz')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = datasets.download_data(MNIST_URL, ext='.gz'); path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with gzip.open(path, 'rb') as f:\n",
    "    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]]),\n",
       " torch.Size([50000, 784]),\n",
       " tensor([5, 0, 4,  ..., 8, 4, 8]),\n",
       " torch.Size([50000]),\n",
       " tensor(0),\n",
       " tensor(9))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))\n",
    "n,c = x_train.shape\n",
    "x_train, x_train.shape, y_train, y_train.shape, y_train.min(), y_train.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert n==y_train.shape[0]==50000\n",
    "test_eq(c,28*28)\n",
    "test_eq(y_train.min(),0)\n",
    "test_eq(y_train.max(),9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpl.rcParams['image.cmap'] = 'gray'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = x_train[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'torch.FloatTensor'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img.view(28,28).type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADgpJREFUeJzt3X+MVfWZx/HPs1j+kKI4aQRCYSnEYJW4082IjSWrxkzVDQZHrekkJjQapn8wiU02ZA3/VNNgyCrslmiamaZYSFpKE3VB0iw0otLGZuKIWC0srTFsO3IDNTjywx9kmGf/mEMzxbnfe+fec++5zPN+JeT+eM6558kNnznn3O+592vuLgDx/EPRDQAoBuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUZc3cmJlxOSHQYO5u1SxX157fzO40syNm9q6ZPVrPawFoLqv12n4zmybpj5I6JQ1Jel1St7sfSqzDnh9osGbs+ZdJetfd33P3c5J+IWllHa8HoInqCf88SX8Z93goe+7vmFmPmQ2a2WAd2wKQs3o+8Jvo0OJzh/Xu3i+pX+KwH2gl9ez5hyTNH/f4y5KO1dcOgGapJ/yvS7rGzL5iZtMlfVvSrnzaAtBoNR/2u/uImfVK2iNpmqQt7v6H3DoD0FA1D/XVtDHO+YGGa8pFPgAuXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMU3ZJkZkclnZZ0XtKIu3fk0RTyM23atGT9yiuvbOj2e3t7y9Yuv/zy5LpLlixJ1tesWZOsP/XUU2Vr3d3dyXU//fTTZH3Dhg3J+uOPP56st4K6wp+5zd0/yOF1ADQRh/1AUPWG3yXtNbM3zKwnj4YANEe9h/3fcPdjZna1pF+b2f+6+/7xC2R/FPjDALSYuvb87n4suz0h6QVJyyZYpt/dO/gwEGgtNYffzGaY2cwL9yV9U9I7eTUGoLHqOeyfLekFM7vwOj939//JpSsADVdz+N39PUn/lGMvU9aCBQuS9enTpyfrN998c7K+fPnysrVZs2Yl173vvvuS9SINDQ0l65s3b07Wu7q6ytZOnz6dXPett95K1l999dVk/VLAUB8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35m3MrHkba6L29vZkfd++fcl6o79W26pGR0eT9YceeihZP3PmTM3bLpVKyfqHH36YrB85cqTmbTeau1s1y7HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOfPQVtbW7I+MDCQrC9atCjPdnJVqffh4eFk/bbbbitbO3fuXHLdqNc/1ItxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVB6z9IZ38uTJZH3t2rXJ+ooVK5L1N998M1mv9BPWKQcPHkzWOzs7k/WzZ88m69dff33Z2iOPPJJcF43Fnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX4z2yJphaQT7r40e65N0g5JCyUdlfSAu6d/6FxT9/v89briiiuS9UrTSff19ZWtPfzww8l1H3zwwWR9+/btyTpaT57f5/+ppDsveu5RSS+5+zWSXsoeA7iEVAy/u++XdPElbCslbc3ub5V0T859AWiwWs/5Z7t7SZKy26vzawlAMzT82n4z65HU0+jtAJicWvf8x81sriRltyfKLeju/e7e4e4dNW4LQAPUGv5dklZl91dJ2plPOwCapWL4zWy7pN9JWmJmQ2b2sKQNkjrN7E+SOrPHAC4hFc/53b27TOn2nHsJ69SpU3Wt/9FHH9W87urVq5P1HTt2JOujo6M1bxvF4go/ICjCDwRF+IGgCD8QFOEHgiL8QFBM0T0FzJgxo2ztxRdfTK57yy23JOt33XVXsr53795kHc3HFN0Akgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+ae4xYsXJ+sHDhxI1oeHh5P1l19+OVkfHBwsW3vmmWeS6zbz/+ZUwjg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gurq6kvVnn302WZ85c2bN2163bl2yvm3btmS9VCrVvO2pjHF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXF+M9siaYWkE+6+NHvuMUmrJf01W2ydu/+q4sYY57/kLF26NFnftGlTsn777bXP5N7X15esr1+/Pll///33a972pSzPcf6fSrpzguf/093bs38Vgw+gtVQMv7vvl3SyCb0AaKJ6zvl7zez3ZrbFzK7KrSMATVFr+H8kabGkdkklSRvLLWhmPWY2aGblf8wNQNPVFH53P+7u5919VNKPJS1LLNvv7h3u3lFrkwDyV1P4zWzuuIddkt7Jpx0AzXJZpQXMbLukWyV9ycyGJH1f0q1m1i7JJR2V9N0G9gigAfg+P+oya9asZP3uu+8uW6v0WwFm6eHqffv2JeudnZ3J+lTF9/kBJBF+ICjCDwRF+IGgCD8QFOEHgmKoD4X57LPPkvXLLktfhjIyMpKs33HHHWVrr7zySnLdSxlDfQCSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrf50dsN9xwQ7J+//33J+s33nhj2VqlcfxKDh06lKzv37+/rtef6tjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPcUuWLEnWe3t7k/V77703WZ8zZ86ke6rW+fPnk/VSqZSsj46O5tnOlMOeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2bzJW2TNEfSqKR+d/+hmbVJ2iFpoaSjkh5w9w8b12pclcbSu7u7y9YqjeMvXLiwlpZyMTg4mKyvX78+Wd+1a1ee7YRTzZ5/RNK/uftXJX1d0hozu07So5JecvdrJL2UPQZwiagYfncvufuB7P5pSYclzZO0UtLWbLGtku5pVJMA8jepc34zWyjpa5IGJM1295I09gdC0tV5Nwegcaq+tt/MvijpOUnfc/dTZlVNByYz65HUU1t7ABqlqj2/mX1BY8H/mbs/nz193MzmZvW5kk5MtK6797t7h7t35NEwgHxUDL+N7eJ/Iumwu28aV9olaVV2f5Wknfm3B6BRKk7RbWbLJf1G0tsaG+qTpHUaO+//paQFkv4s6VvufrLCa4Wconv27NnJ+nXXXZesP/3008n6tddeO+me8jIwMJCsP/nkk2VrO3em9xd8Jbc21U7RXfGc391/K6nci90+maYAtA6u8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93V6mtra1sra+vL7lue3t7sr5o0aKaesrDa6+9lqxv3LgxWd+zZ0+y/sknn0y6JzQHe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCrMOP9NN92UrK9duzZZX7ZsWdnavHnzauopLx9//HHZ2ubNm5PrPvHEE8n62bNna+oJrY89PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EFWacv6urq656PQ4dOpSs7969O1kfGRlJ1lPfuR8eHk6ui7jY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUObu6QXM5kvaJmmOpFFJ/e7+QzN7TNJqSX/NFl3n7r+q8FrpjQGom7tbNctVE/65kua6+wEzmynpDUn3SHpA0hl3f6rapgg/0HjVhr/iFX7uXpJUyu6fNrPDkor96RoAdZvUOb+ZLZT0NUkD2VO9ZvZ7M9tiZleVWafHzAbNbLCuTgHkquJh/98WNPuipFclrXf3581stqQPJLmkH2js1OChCq/BYT/QYLmd80uSmX1B0m5Je9x90wT1hZJ2u/vSCq9D+IEGqzb8FQ/7zcwk/UTS4fHBzz4IvKBL0juTbRJAcar5tH+5pN9IeltjQ32StE5St6R2jR32H5X03ezDwdRrsecHGizXw/68EH6g8XI77AcwNRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCavYU3R9I+r9xj7+UPdeKWrW3Vu1Lorda5dnbP1a7YFO/z/+5jZsNuntHYQ0ktGpvrdqXRG+1Kqo3DvuBoAg/EFTR4e8vePsprdpbq/Yl0VutCumt0HN+AMUpes8PoCCFhN/M7jSzI2b2rpk9WkQP5ZjZUTN728wOFj3FWDYN2gkze2fcc21m9msz+1N2O+E0aQX19piZvZ+9dwfN7F8L6m2+mb1sZofN7A9m9kj2fKHvXaKvQt63ph/2m9k0SX+U1ClpSNLrkrrd/VBTGynDzI5K6nD3wseEzexfJJ2RtO3CbEhm9h+STrr7huwP51Xu/u8t0ttjmuTMzQ3qrdzM0t9Rge9dnjNe56GIPf8ySe+6+3vufk7SLyStLKCPlufu+yWdvOjplZK2Zve3auw/T9OV6a0luHvJ3Q9k909LujCzdKHvXaKvQhQR/nmS/jLu8ZBaa8pvl7TXzN4ws56im5nA7AszI2W3Vxfcz8UqztzcTBfNLN0y710tM17nrYjwTzSbSCsNOXzD3f9Z0l2S1mSHt6jOjyQt1tg0biVJG4tsJptZ+jlJ33P3U0X2Mt4EfRXyvhUR/iFJ88c9/rKkYwX0MSF3P5bdnpD0gsZOU1rJ8QuTpGa3Jwru52/c/bi7n3f3UUk/VoHvXTaz9HOSfubuz2dPF/7eTdRXUe9bEeF/XdI1ZvYVM5su6duSdhXQx+eY2YzsgxiZ2QxJ31TrzT68S9Kq7P4qSTsL7OXvtMrMzeVmllbB712rzXhdyEU+2VDGf0maJmmLu69vehMTMLNFGtvbS2PfePx5kb2Z2XZJt2rsW1/HJX1f0n9L+qWkBZL+LOlb7t70D97K9HarJjlzc4N6Kzez9IAKfO/ynPE6l364wg+IiSv8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/Ex0YKZYOZcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(img.view((28,28)));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initial python model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " [Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=2342)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " [Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=2342)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = torch.randn(784,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bias = torch.zeros(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Matrix multiplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul(a,b):\n",
    "    ar,ac = a.shape # n_rows * n_cols\n",
    "    br,bc = b.shape\n",
    "    assert ac==br\n",
    "    c = torch.zeros(ar, bc)\n",
    "    for i in range(ar):\n",
    "        for j in range(bc):\n",
    "            for k in range(ac): # or br\n",
    "                c[i,j] += a[i,k] * b[k,j]\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m1 = x_valid[:5]\n",
    "m2 = weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 784]), torch.Size([784, 10]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m1.shape,m2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 818 ms, sys: 0 ns, total: 818 ms\n",
      "Wall time: 835 ms\n"
     ]
    }
   ],
   "source": [
    "%time t1=matmul(m1, m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 10])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t1.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is kinda slow - what if we could speed it up by 50,000 times? Let's try!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Elementwise ops"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Operators (+,-,\\*,/,>,<,==) are usually element-wise.\n",
    "\n",
    "Examples of element-wise operations:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " [Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=2682)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([10.,  6., -4.]), tensor([2., 8., 7.]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tensor([10., 6, -4])\n",
    "b = tensor([2., 8, 7])\n",
    "a,b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([12., 14.,  3.])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6667)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a < b).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2., 3.],\n",
       "        [4., 5., 6.],\n",
       "        [7., 8., 9.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tensor([[1., 2, 3], [4,5,6], [7,8,9]]); m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Frobenius norm:\n",
    "\n",
    "$$\\| A \\|_F = \\left( \\sum_{i,j=1}^n | a_{ij} |^2 \\right)^{1/2}$$\n",
    "\n",
    "*Hint*: you don't normally need to write equations in LaTeX yourself, instead, you can click 'edit' in Wikipedia and copy the LaTeX from there (which is what I did for the above equation). Or on arxiv.org, click \"Download: Other formats\" in the top right, then \"Download source\"; rename the downloaded file to end in `.tgz` if it doesn't already, and you should find the source there, including the equations to copy and paste."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(16.8819)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(m*m).sum().sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Elementwise matmul"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul(a,b):\n",
    "    ar,ac = a.shape\n",
    "    br,bc = b.shape\n",
    "    assert ac==br\n",
    "    c = torch.zeros(ar, bc)\n",
    "    for i in range(ar):\n",
    "        for j in range(bc):\n",
    "            # Any trailing \",:\" can be removed\n",
    "            c[i,j] = (a[i,:] * b[:,j]).sum()\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.39 ms ± 70 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 10 _=matmul(m1, m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "178.02"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "890.1/5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def near(a,b): return torch.allclose(a, b, rtol=1e-3, atol=1e-5)\n",
    "def test_near(a,b): test(a,b,near)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_near(t1,matmul(m1, m2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Broadcasting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The term **broadcasting** describes how arrays with different shapes are treated during arithmetic operations.  The term broadcasting was first used by Numpy.\n",
    "\n",
    "From the [Numpy Documentation](https://docs.scipy.org/doc/numpy-1.10.0/user/basics.broadcasting.html):\n",
    "\n",
    "    The term broadcasting describes how numpy treats arrays with \n",
    "    different shapes during arithmetic operations. Subject to certain \n",
    "    constraints, the smaller array is “broadcast” across the larger \n",
    "    array so that they have compatible shapes. Broadcasting provides a \n",
    "    means of vectorizing array operations so that looping occurs in C\n",
    "    instead of Python. It does this without making needless copies of \n",
    "    data and usually leads to efficient algorithm implementations.\n",
    "    \n",
    "In addition to the efficiency of broadcasting, it allows developers to write less code, which typically leads to fewer errors.\n",
    "\n",
    "*This section was adapted from [Chapter 4](http://nbviewer.jupyter.org/github/fastai/numerical-linear-algebra/blob/master/nbs/4.%20Compressed%20Sensing%20of%20CT%20Scans%20with%20Robust%20Regression.ipynb#4.-Compressed-Sensing-of-CT-Scans-with-Robust-Regression) of the fast.ai [Computational Linear Algebra](https://github.com/fastai/numerical-linear-algebra) course.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=3110)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Broadcasting with a scalar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([10.,  6., -4.])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 1, 0], dtype=torch.uint8)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a > 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How are we able to do a > 0?  0 is being **broadcast** to have the same dimensions as a.\n",
    "\n",
    "For instance you can normalize our dataset by subtracting the mean (a scalar) from the entire data set (a matrix) and dividing by the standard deviation (another scalar), using broadcasting.\n",
    "\n",
    "Other examples of broadcasting with a scalar:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([11.,  7., -3.])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2., 3.],\n",
       "        [4., 5., 6.],\n",
       "        [7., 8., 9.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.,  4.,  6.],\n",
       "        [ 8., 10., 12.],\n",
       "        [14., 16., 18.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "2*m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Broadcasting a vector to a matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also broadcast a vector to a matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([10., 20., 30.])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c = tensor([10.,20,30]); c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2., 3.],\n",
       "        [4., 5., 6.],\n",
       "        [7., 8., 9.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3, 3]), torch.Size([3]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.shape,c.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[11., 22., 33.],\n",
       "        [14., 25., 36.],\n",
       "        [17., 28., 39.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m + c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[11., 22., 33.],\n",
       "        [14., 25., 36.],\n",
       "        [17., 28., 39.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c + m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We don't really copy the rows, but it looks as if we did. In fact, the rows are given a *stride* of 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = c.expand_as(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10., 20., 30.],\n",
       "        [10., 20., 30.],\n",
       "        [10., 20., 30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[11., 22., 33.],\n",
       "        [14., 25., 36.],\n",
       "        [17., 28., 39.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m + t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " 10.0\n",
       " 20.0\n",
       " 30.0\n",
       "[torch.FloatStorage of size 3]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.storage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((0, 1), torch.Size([3, 3]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t.stride(), t.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can index with the special value [None] or use `unsqueeze()` to convert a 1-dimensional array into a 2-dimensional array (although one of those dimensions has value 1)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10., 20., 30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.],\n",
       "        [20.],\n",
       "        [30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2., 3.],\n",
       "        [4., 5., 6.],\n",
       "        [7., 8., 9.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.shape, c.unsqueeze(0).shape,c.unsqueeze(1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c.shape, c[None].shape,c[:,None].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can always skip trailling ':'s. And '...' means '*all preceding dimensions*'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 3]), torch.Size([3, 1]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[None].shape,c[...,None].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10., 10., 10.],\n",
       "        [20., 20., 20.],\n",
       "        [30., 30., 30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[:,None].expand_as(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[11., 12., 13.],\n",
       "        [24., 25., 26.],\n",
       "        [37., 38., 39.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m + c[:,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.],\n",
       "        [20.],\n",
       "        [30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[:,None]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Matmul with broadcasting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matmul(a,b):\n",
    "    ar,ac = a.shape\n",
    "    br,bc = b.shape\n",
    "    assert ac==br\n",
    "    c = torch.zeros(ar, bc)\n",
    "    for i in range(ar):\n",
    "#       c[i,j] = (a[i,:]          * b[:,j]).sum() # previous\n",
    "        c[i]   = (a[i  ].unsqueeze(-1) * b).sum(dim=0)\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "254 µs ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 10 _=matmul(m1, m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3194.945848375451"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "885000/277"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_near(t1, matmul(m1, m2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Broadcasting Rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10., 20., 30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[None,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[None,:].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[10.],\n",
       "        [20.],\n",
       "        [30.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[:,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 1])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[:,None].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[100., 200., 300.],\n",
       "        [200., 400., 600.],\n",
       "        [300., 600., 900.]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[None,:] * c[:,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 1],\n",
       "        [0, 0, 1],\n",
       "        [0, 0, 0]], dtype=torch.uint8)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c[None] > c[:,None]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When operating on two arrays/tensors, Numpy/PyTorch compares their shapes element-wise. It starts with the **trailing dimensions**, and works its way forward. Two dimensions are **compatible** when\n",
    "\n",
    "- they are equal, or\n",
    "- one of them is 1, in which case that dimension is broadcasted to make it the same size\n",
    "\n",
    "Arrays do not need to have the same number of dimensions. For example, if you have a `256*256*3` array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules, shows that they are compatible:\n",
    "\n",
    "    Image  (3d array): 256 x 256 x 3\n",
    "    Scale  (1d array):             3\n",
    "    Result (3d array): 256 x 256 x 3\n",
    "\n",
    "The [numpy documentation](https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html#general-broadcasting-rules) includes several examples of what dimensions can and can not be broadcast together."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Einstein summation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Einstein summation (`einsum`) is a compact representation for combining products and sums in a general way. From the numpy docs:\n",
    "\n",
    "\"The subscripts string is a comma-separated list of subscript labels, where each label refers to a dimension of the corresponding operand. Whenever a label is repeated it is summed, so `np.einsum('i,i', a, b)` is equivalent to `np.inner(a,b)`. If a label appears only once, it is not summed, so `np.einsum('i', a)` produces a view of a with no changes.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=4280)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# c[i,j] += a[i,k] * b[k,j]\n",
    "# c[i,j] = (a[i,:] * b[:,j]).sum()\n",
    "def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57.2 µs ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 10 _=matmul(m1, m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16090.90909090909"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "885000/55"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_near(t1, matmul(m1, m2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pytorch op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use pytorch's function or operator directly for matrix multiplication."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 8 video](https://course19.fast.ai/videos/?lesson=8&t=4702)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18.2 µs ± 6.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 10 t2 = m1.matmul(m2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "49166.666666666664"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# time comparison vs pure python:\n",
    "885000/18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t2 = m1@m2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_near(t1, t2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 784]), torch.Size([784, 10]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m1.shape,m2.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 01_matmul.ipynb to nb_01.py\r\n"
     ]
    }
   ],
   "source": [
    "!python notebook2script.py 01_matmul.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
